from __future__ import annotations

import os
import time
import random
from typing import List, Dict, Tuple, Optional
import requests


class OpenAIStyleClient:
    """Minimal client for OpenAI-compatible ``/chat/completions`` endpoints.

    When ``api_base`` contains ``openrouter.ai`` the client additionally sends
    the optional ``HTTP-Referer`` / ``X-Title`` headers and the ``reasoning`` /
    ``include_reasoning`` request fields.

    Environment variables read by this class:

    * ``OPENAI_API_BASE``: default base URL.
    * ``OPENAI_API_KEY``, ``OPENROUTER_API_KEY``, ``LIVEBENCH_API_KEY``: key
      fallbacks, in that order.
    * ``OPENROUTER_REFERER`` / ``OPENAI_API_REFERRER`` and
      ``OPENROUTER_TITLE`` / ``OPENAI_API_TITLE``: optional OpenRouter headers.
    * ``COTHINKER_OR_REASON_BUDGET`` / ``COTHINKER_OR_REASON_EXCLUDE``:
      reasoning defaults used when ``chat`` gets no per-call budget.
    * ``COTHINKER_HTTP_MAX_RETRIES`` / ``COTHINKER_HTTP_BACKOFF_BASE`` /
      ``COTHINKER_HTTP_TIMEOUT``: retry count, backoff base (s), timeout (s).
    * ``COTHINKER_LOG_RESPONSES``: when truthy, debug snippets go to stderr.
    """

    # HTTP statuses that are always retried with backoff.
    _RETRY_STATUSES = frozenset({408, 409, 425, 429, 500, 502, 503, 504})
    # Content-part types that carry reasoning rather than the answer.
    _REASONING_PART_TYPES = frozenset({"reasoning", "chain_of_thought", "cot"})
    # Environment values interpreted as "enabled".
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __init__(self, model: str, api_base: str | None = None, api_key: str | None = None):
        """Store the model name and resolve the base URL and API key.

        Args:
            model: Model identifier sent in every request body.
            api_base: Base URL; falls back to ``OPENAI_API_BASE`` and then to
                ``https://api.openai.com/v1``.
            api_key: API key; falls back to the environment (see below).
        """
        self.model = model
        self.api_base = api_base or os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
        # Key fallback order: explicit param > OPENAI_API_KEY > OPENROUTER_API_KEY > LIVEBENCH_API_KEY
        self.api_key = (
            api_key
            or os.environ.get("OPENAI_API_KEY")
            or os.environ.get("OPENROUTER_API_KEY")
            or os.environ.get("LIVEBENCH_API_KEY")
        )

    # ------------------------------------------------------------------
    # Request construction
    # ------------------------------------------------------------------

    def _is_openrouter(self) -> bool:
        """Return True when the configured base URL points at OpenRouter."""
        return "openrouter.ai" in (self.api_base or "")

    def _chat_url(self) -> str:
        """Return the chat-completions URL for the configured base."""
        # Strip trailing slashes so ".../v1/" does not become ".../v1//chat/completions".
        return f"{(self.api_base or '').rstrip('/')}/chat/completions"

    def _build_headers(self) -> Dict:
        """Return the request headers, including optional OpenRouter headers."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        # OpenRouter recommends Referer and X-Title for better rate limit handling
        if self._is_openrouter():
            referer = os.environ.get("OPENROUTER_REFERER") or os.environ.get("OPENAI_API_REFERRER")
            title = os.environ.get("OPENROUTER_TITLE") or os.environ.get("OPENAI_API_TITLE")
            if referer:
                headers["HTTP-Referer"] = referer
            if title:
                headers["X-Title"] = title
        return headers

    def _build_body(
        self,
        messages: List[Dict],
        temperature: float,
        max_tokens: int,
        reasoning_max_tokens: Optional[int] = None,
        reasoning_exclude: Optional[bool] = None,
    ) -> Dict:
        """Return the JSON body, adding OpenRouter reasoning controls if set.

        A per-call ``reasoning_max_tokens`` takes precedence over the
        ``COTHINKER_OR_REASON_*`` environment defaults. A per-call budget that
        is zero or negative omits the ``reasoning`` block entirely.
        """
        body: Dict = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if not self._is_openrouter():
            return body
        try:
            if reasoning_max_tokens is not None:
                budget = int(reasoning_max_tokens)
                if budget > 0:
                    body["reasoning"] = {
                        "max_tokens": budget,
                        "budget": budget,
                        "exclude": bool(reasoning_exclude),
                    }
                    if reasoning_exclude is not None:
                        body["include_reasoning"] = not bool(reasoning_exclude)
                elif reasoning_exclude:
                    # No reasoning block, but still ask the provider to leave it out.
                    body["include_reasoning"] = False
            else:
                env_budget = os.environ.get("COTHINKER_OR_REASON_BUDGET")
                if env_budget is not None:
                    exclude = os.environ.get("COTHINKER_OR_REASON_EXCLUDE", "0").lower() in self._TRUTHY
                    budget = int(env_budget)
                    body["reasoning"] = {"max_tokens": budget, "budget": budget, "exclude": exclude}
                    body["include_reasoning"] = not exclude
        except (TypeError, ValueError, OverflowError):
            # Best effort: an unparsable budget means "send no reasoning
            # controls", not "fail the request".
            pass
        return body

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _first_message(data: object) -> Dict:
        """Return ``choices[0].message`` from a payload, or ``{}`` if absent."""
        if not isinstance(data, dict):
            return {}
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return {}
        first = choices[0]
        if not isinstance(first, dict):
            return {}
        message = first.get("message")
        return message if isinstance(message, dict) else {}

    @classmethod
    def _extract_content(cls, message: object) -> str:
        """Return the stripped answer text of a message, skipping reasoning parts.

        ``content`` may be a plain string or a list of typed parts; any other
        shape yields an empty string.
        """
        if not isinstance(message, dict):
            return ""
        raw = message.get("content")
        if isinstance(raw, str):
            return raw.strip()
        if not isinstance(raw, list):
            return ""
        texts = []
        for part in raw:
            if not isinstance(part, dict):
                continue
            kind = part.get("type")
            # Skip known reasoning-like channels
            if isinstance(kind, str) and kind in cls._REASONING_PART_TYPES:
                continue
            text = part.get("text")
            if isinstance(text, str) and text.strip():
                texts.append(text.strip())
        return "\n".join(texts).strip()

    @staticmethod
    def _count_tokens(usage: object) -> int:
        """Return the completion token count from a ``usage`` object.

        Uses ``completion_tokens`` (falling back to ``total_tokens``) and adds
        ``completion_tokens_details.reasoning_tokens`` when both are integers.
        Missing, null or non-numeric values count as 0 instead of raising, so
        a malformed ``usage`` cannot discard an otherwise valid completion.
        """
        if not isinstance(usage, dict):
            return 0
        tokens = usage.get("completion_tokens")
        if tokens is None:
            tokens = usage.get("total_tokens", 0)
        details = usage.get("completion_tokens_details")
        if isinstance(details, dict):
            reasoning_tokens = details.get("reasoning_tokens")
            # NOTE(review): some providers may already include reasoning tokens
            # in completion_tokens; confirm this addition does not double count.
            if isinstance(tokens, int) and isinstance(reasoning_tokens, int):
                tokens += reasoning_tokens
        try:
            return int(tokens)
        except (TypeError, ValueError, OverflowError):
            return 0

    @classmethod
    def _parse_completion(cls, data: object, log_empty: bool = False) -> Tuple[str, int]:
        """Return ``(content, tokens)`` parsed from a decoded response payload.

        If the message has no answer text, its ``reasoning`` string (when
        present) is used as the content. With ``log_empty`` set and response
        logging enabled, the raw message is written to stderr first.
        """
        message = cls._first_message(data)
        content = cls._extract_content(message)
        if not content:
            if log_empty and cls._debug_enabled():
                cls._write_debug(f"[CoThinker] DEBUG raw message: {message}\n")
            # Fallback for providers returning reasoning but empty content
            reasoning = message.get("reasoning")
            if isinstance(reasoning, str) and reasoning.strip():
                content = reasoning.strip()
        usage = data.get("usage") if isinstance(data, dict) else None
        return content, cls._count_tokens(usage)

    # ------------------------------------------------------------------
    # Diagnostics and retry plumbing
    # ------------------------------------------------------------------

    @classmethod
    def _debug_enabled(cls) -> bool:
        """Return True when ``COTHINKER_LOG_RESPONSES`` is set to a truthy value."""
        return os.environ.get("COTHINKER_LOG_RESPONSES", "0").lower() in cls._TRUTHY

    @staticmethod
    def _write_debug(text: str) -> None:
        """Write a diagnostic line to stderr, never raising."""
        try:
            import sys

            sys.stderr.write(text)
        except Exception:
            # Diagnostics are optional; a broken stderr must not fail a request.
            pass

    @staticmethod
    def _sleep_backoff(base_backoff: float, attempt: int) -> None:
        """Sleep for an exponential backoff delay with up to 0.5 s of jitter."""
        time.sleep((base_backoff * (2 ** attempt)) + random.uniform(0, 0.5))

    def _retry_without_reasoning(
        self,
        url: str,
        headers: Dict,
        body: Dict,
        timeout_s: float,
    ) -> Optional[Tuple[str, int]]:
        """Re-send the request once without the reasoning fields.

        Returns ``(content, tokens)`` if that request yields non-empty content,
        otherwise ``None``. Transport errors raised by ``requests.post``
        propagate to the caller; HTTP error statuses and undecodable bodies
        yield ``None``.
        """
        plain_body = {
            key: value
            for key, value in body.items()
            if key not in ("reasoning", "include_reasoning")
        }
        resp = requests.post(url, headers=headers, json=plain_body, timeout=timeout_s)
        if resp.status_code in self._RETRY_STATUSES:
            return None
        try:
            resp.raise_for_status()
            data = resp.json()
        except (requests.HTTPError, ValueError):
            return None
        content, num_tokens = self._parse_completion(data)
        if content:
            return content, num_tokens
        return None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def chat(
        self,
        messages: List[Dict],
        temperature: float,
        max_tokens: int,
        reasoning_max_tokens: Optional[int] = None,
        reasoning_exclude: Optional[bool] = None,
    ) -> Tuple[str, int]:
        """Send a chat completion request and return the answer text.

        Timeouts, connection errors, the statuses in ``_RETRY_STATUSES`` and
        other 5xx responses are retried with exponential backoff. Other 4xx
        responses (e.g. 400/401/403/404) fail immediately, since repeating the
        same request cannot succeed. On OpenRouter, when a request that carried
        a ``reasoning`` block comes back undecodable or with empty content, one
        extra request without the reasoning fields is attempted.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            reasoning_max_tokens: OpenRouter reasoning budget for this call;
                ``None`` defers to ``COTHINKER_OR_REASON_BUDGET``.
            reasoning_exclude: Whether reasoning should be left out of the
                response (OpenRouter only).

        Returns:
            ``(content, num_tokens)``. ``content`` may be empty if the provider
            returned no text. On failure the result is
            ``("[API_ERROR] <last error>", 0)``; request errors are not raised.

        Raises:
            ValueError: If a ``COTHINKER_HTTP_*`` variable is not numeric.
        """
        url = self._chat_url()
        headers = self._build_headers()
        body = self._build_body(messages, temperature, max_tokens, reasoning_max_tokens, reasoning_exclude)

        # Retry/backoff config. Clamp so a negative setting still makes one attempt.
        max_retries = max(0, int(os.environ.get("COTHINKER_HTTP_MAX_RETRIES", "4")))
        base_backoff = float(os.environ.get("COTHINKER_HTTP_BACKOFF_BASE", "1.0"))
        timeout_s = float(os.environ.get("COTHINKER_HTTP_TIMEOUT", "240"))

        # The reasoning-free re-request is allowed at most once per call.
        plain_retry_available = self._is_openrouter() and "reasoning" in body
        last_err: Optional[Exception] = None

        for attempt in range(max_retries + 1):
            try:
                resp = requests.post(url, headers=headers, json=body, timeout=timeout_s)
                if resp.status_code in self._RETRY_STATUSES:
                    raise requests.HTTPError(f"{resp.status_code} {resp.text[:256]}")
                try:
                    resp.raise_for_status()
                except requests.HTTPError as http_err:
                    if resp.status_code >= 500:
                        # Unlisted server errors stay retryable.
                        raise
                    # Client errors outside the retry set will not change on retry.
                    last_err = http_err
                    break

                try:
                    data = resp.json()
                except ValueError as decode_err:
                    # JSON decode failed: treat as retryable, but first try the
                    # one-shot request without reasoning if applicable.
                    last_err = decode_err
                    if plain_retry_available:
                        plain_retry_available = False
                        try:
                            salvaged = self._retry_without_reasoning(url, headers, body, timeout_s)
                        except Exception:
                            # Best-effort salvage; the decode error is already
                            # recorded and the normal backoff continues below.
                            salvaged = None
                        if salvaged is not None:
                            return salvaged
                    # Log snippet to help diagnose provider hiccups when enabled
                    if self._debug_enabled():
                        self._write_debug(
                            f"[CoThinker] DEBUG decode failed. Status={resp.status_code} Snippet={resp.text[:200]}\n"
                        )
                    if attempt >= max_retries:
                        break
                    self._sleep_backoff(base_backoff, attempt)
                    continue

                content, num_tokens = self._parse_completion(data, log_empty=True)
                if not content and plain_retry_available:
                    plain_retry_available = False
                    # Direct re-request without counting against the backoff
                    # loop; transport errors fall through to the handlers below.
                    salvaged = self._retry_without_reasoning(url, headers, body, timeout_s)
                    if salvaged is not None:
                        return salvaged
                return content, num_tokens
            except (
                requests.Timeout,
                requests.ConnectionError,
                requests.HTTPError,
                requests.exceptions.ChunkedEncodingError,
            ) as e:
                last_err = e
                if attempt >= max_retries:
                    break
                # Exponential backoff with jitter
                self._sleep_backoff(base_backoff, attempt)
                continue
            except Exception as e:
                # Non-retryable or unknown error: reported through the return
                # value, because callers expect a string rather than a raise.
                last_err = e
                break
        return f"[API_ERROR] {last_err}", 0
